# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
# ----------------------------------------------------------------------------

import re
import sys
import textwrap
from pathlib import Path

import pybind11_stubgen

# Needed so sys.modules["cc_sym"] exists
from symforce import cc_sym  # noqa: F401
from symforce import path_util
from symforce.codegen import format_util
from symforce.test_util import TestCase
from symforce.test_util.stubs_util import patch_handle_docstring
from symforce.test_util.stubs_util import patch_lcmtype_imports
from symforce.test_util.stubs_util import patch_remove_parameters

# Apply the helpers from symforce.test_util.stubs_util at import time, before any stubs are
# generated below. Judging by their names, these adjust pybind11-stubgen's handling of lcmtype
# imports, docstrings, and numpy array parameter removal -- see stubs_util for the details.
# NOTE(review): the order may matter if the patches touch the same stubgen internals; keep as-is.
patch_lcmtype_imports()
patch_handle_docstring()
patch_remove_parameters()


class SymforceCCSymStubsCodegenTest(TestCase):
    """
    Generates the type stubs for the cc_sym pybind11 module with pybind11-stubgen, post-processes
    them, and checks them against the copy at symforce/pybind/cc_sym.pyi
    """

    def cc_sym_stubgen_output(self) -> str:
        """
        Returns the contents of the stub file produced by pybind11-stubgen on module cc_sym
        """
        output_dir = self.make_output_dir("sf_cc_sym_stubgen_output")

        # Sanity check that the imported cc_sym module is backed by a file on disk
        assert sys.modules["cc_sym"].__file__ is not None

        import logging

        # Configure root logging at INFO so messages logged during stub generation are printed
        logging.basicConfig(
            level=logging.INFO,
            format="%(name)s - [%(levelname)7s] %(message)s",
        )

        cli_args = [
            "cc_sym",
            # The parameterized numpy array annotations are correct, but our numpy/mypy versions
            # can't handle them yet, so strip the parameters
            "--numpy-array-remove-parameters",
            # Ignore undefined n and m typevars for numpy arrays
            # This shouldn't be needed with --numpy-array-remove-parameters, but it seems like it is
            "--ignore-unresolved-names=n|m",
            f"--output-dir={output_dir}",
            "--exit-code",
        ]

        # Parse the arguments exactly as the pybind11-stubgen CLI would
        args = pybind11_stubgen.arg_parser().parse_args(
            cli_args, namespace=pybind11_stubgen.CLIArgs()
        )

        parser = pybind11_stubgen.stub_parser_from_args(args)
        printer = pybind11_stubgen.Printer(invalid_expr_as_ellipses=False)

        out_dir, sub_dir = pybind11_stubgen.to_output_and_subdir(
            output_dir=args.output_dir,
            module_name=args.module_name,
            root_suffix=args.root_suffix,
        )

        pybind11_stubgen.run(
            parser,
            printer,
            args.module_name,
            out_dir,
            sub_dir=sub_dir,
            dry_run=args.dry_run,
            writer=pybind11_stubgen.Writer(stub_ext=args.stub_extension),
        )

        generated_file = output_dir / "cc_sym.pyi"

        return generated_file.read_text()

    def test_generate_cc_sym_stubs(self) -> None:
        """
        Generates cc_sym.pyi, applies symforce-specific fixups and a header, formats it, and passes
        it to compare_or_update_file against symforce/pybind/cc_sym.pyi
        """
        output_dir = self.make_output_dir("sf_cc_sym_stubs_codegen_test")

        stubgen_output = self.cc_sym_stubgen_output()

        # Change type of OptimizationStats.best_linearization to be Optional[Linearization].
        # The "." is escaped so only a literal "typing.Any" matches, and \b keeps the pattern
        # from matching a prefix of a longer name such as "typing.AnyStr".
        stubgen_output, num_replacements = re.subn(
            r"def best_linearization\(self\) -> typing\.Any\b",
            "def best_linearization(self) -> typing.Optional[Linearization]",
            stubgen_output,
        )
        # Fail loudly if the stubgen output changed and this fixup no longer applies, instead of
        # silently producing stubs where best_linearization is typed as typing.Any
        self.assertGreater(
            num_replacements,
            0,
            "Could not find the best_linearization getter in the generated cc_sym stubs",
        )

        # The template is dedented before formatting, so the indentation of the stubgen output
        # itself is left untouched
        stubgen_output = textwrap.dedent(
            """
            # -----------------------------------------------------------------------------
            # This file was autogenerated by symforce by:
            #     symforce_cc_sym_stubs_codegen_test
            # Do NOT modify by hand.
            # -----------------------------------------------------------------------------

            {stubgen_output}
            """
        ).format(stubgen_output=stubgen_output)

        stubgen_output = format_util.format_py(
            stubgen_output, str(Path(__file__).parent / "cc_sym.pyi")
        )

        (output_dir / "cc_sym.pyi").write_text(stubgen_output)

        self.compare_or_update_file(
            new_file=output_dir / "cc_sym.pyi",
            path=path_util.symforce_data_root(__file__) / "symforce" / "pybind" / "cc_sym.pyi",
        )


# When this file is executed directly, run its tests through symforce's TestCase.main()
if __name__ == "__main__":
    TestCase.main()
